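"""
Coopernaut sender policy: encodes the ego agent's LiDAR sweep into a shareable
representation, plus small action/message helpers.
"""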
import os
import yaml
import math
import torch
import random
import argparse
import numpy as np
from omegaconf import OmegaConf
from models.point_transformer import PointTransformer
from models.cooperative_point_transformer import CooperativePointTransformer

from policies.utils.episode_memory import EpisodeMemory
from policies.base_policy import BasePolicy
from policies.utils.coopernaut_utils import get_config
from policies.utils.lidar_processor import LidarProcessorConfig, filter_lidar_by_boundary, lidar_to_bev_v2, pad_or_truncate, pc_to_car_alignment, Sparse_Quantize, TransformMatrix_WorldCoords

class CoopernautSenderPolicy(BasePolicy):
    """Sender-side CooperNaut policy.

    Encodes the agent's own LiDAR sweep with the ``backbone_other`` branch of a
    pretrained ``CooperativePointTransformer`` and returns the encoded
    features, their point coordinates and the agent's world transform matrix
    as plain Python lists. ``learn`` and ``record_episode`` are no-ops.
    """

    def __init__(self, agent_id, config_file):
        self.agent_id = agent_id
        self.episode_return = 0
        self.episode_memory = EpisodeMemory()
        # get_config returns both the model config and a checkpoint path;
        # num_checkpoint=105 presumably picks the checkpoint index — confirm.
        self.config, self.model_path = get_config(config_file, num_checkpoint=105)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Only CooperativePointTransformer ("cpt") checkpoints can be loaded below.
        assert 'cpt' in self.model_path, (
            f"expected a CooperativePointTransformer ('cpt') checkpoint, "
            f"got {self.model_path!r}"
        )
        self.model = CooperativePointTransformer(self.config).to(self.device)
        self.model.eval()
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.max_num_neighbors = self.config.get('max_num_neighbors')
        self.npoints = self.config.npoints
        self.step_count = 0

    def reset(self):
        """Reset per-episode counters (step count and accumulated return)."""
        self.step_count = 0
        self.episode_return = 0

    def observe(self, obs, reward, terminated, truncated, info):
        """Store the latest observation and accumulate ``reward`` into the episode return.

        ``terminated``, ``truncated`` and ``info`` are accepted for interface
        compatibility but are not used.
        """
        self.observation = obs
        self.episode_return += reward

    def act(self):
        """Encode the most recent observation into a message for other agents.

        Reads ``"LIDAR"``, ``"transform"`` and ``"bbox"`` from the stored
        observation. The observation itself is not modified.

        Returns:
            dict with keys ``"representation"``, ``"xyz"`` (outputs of
            ``model.backbone_other``) and ``"transform"`` (world transform
            matrix), each converted to nested Python lists.
        """
        # NOTE(review): assumes observation["LIDAR"][1] is an (N, >=3) point
        # array whose first three columns are x, y, z — confirm with the env.
        lidar = self.observation["LIDAR"][1][:, :3]
        transform = self.observation["transform"]
        bbox = self.observation["bbox"]
        lidar, transform = self.process_lidar(lidar, transform, bbox)
        # Pure inference: avoid building an autograd graph we would discard.
        with torch.no_grad():
            representation, xyz, _ = self.model.backbone_other(lidar)
        return {
            "representation": representation.detach().cpu().numpy().tolist(),
            "xyz": xyz.detach().cpu().numpy().tolist(),
            "transform": transform.tolist(),
        }

    def get_episode_return(self):
        """Return the reward accumulated since the last ``reset``."""
        return self.episode_return

    def record_episode(self):
        """No-op: the sender policy records nothing."""
        pass

    def learn(self):
        """No-op: the sender policy is inference-only."""
        pass

    def process_lidar(self, lidar, transform, bbox):
        """Prepare a LiDAR sweep and the ego transform for the sender backbone.

        Neither ``lidar`` nor ``transform`` is left modified: the point array is
        copied before its z column is shifted, and ``transform.location.z`` is
        zeroed only while the world matrix is computed, then restored.

        Args:
            lidar: point array whose column 2 is z (see ``act``).
            transform: ego transform exposing ``location.z``.
            bbox: ego bounding box; ``bbox.extent.z`` sets the z offset.

        Returns:
            (lidar_tensor, transform_matrix): a float tensor on ``self.device``
            with a leading batch dimension of 1, and the world transform
            matrix as an ``np.ndarray``.
        """
        z_compensation = 2 * abs(bbox.extent.z) + 0.5  # LidarRoofTopDistance
        # Compute the matrix with the transform flattened to z=0, but restore the
        # caller's value so the observation object is not silently mutated.
        original_z = transform.location.z
        transform.location.z = 0
        try:
            transform_matrix = np.array(TransformMatrix_WorldCoords(transform).matrix)
        finally:
            transform.location.z = original_z
        # `lidar` is usually a view into the observation buffer: copy before the
        # shift so the observation is not corrupted (and the offset does not
        # compound if the same observation is processed more than once).
        lidar = np.array(lidar, copy=True)
        lidar[:, 2] += z_compensation
        lidar = pc_to_car_alignment(lidar)
        lidar = Sparse_Quantize(lidar)
        # Drop duplicate points (rows) produced by quantization.
        lidar = np.unique(lidar, axis=0)
        lidar = pad_or_truncate(lidar, self.npoints)
        lidar = torch.from_numpy(np.array([lidar])).float().to(self.device)
        return lidar, transform_matrix

def parse_action(control):
    """Turn a ``(throttle, brake, steer)`` tensor triple into a clipped control dict.

    Each tensor is moved to CPU, detached and converted to NumPy. Steer is
    clipped to [-1, 1]; throttle and brake are clipped to [0, 1]. The resulting
    dict is printed before being returned.

    Args:
        control: sequence of exactly three tensors, ordered throttle, brake, steer.

    Returns:
        dict with keys ``'steer'``, ``'throttle'`` and ``'brake'``.
    """
    throttle, brake, steer = control
    throttle_np, brake_np, steer_np = [
        t.cpu().detach().numpy() for t in (throttle, brake, steer)
    ]
    action = {
        'steer': np.clip(steer_np, -1.0, 1.0),
        'throttle': np.clip(throttle_np, 0.0, 1.0),
        'brake': np.clip(brake_np, 0.0, 1.0),
    }
    print(action)
    return action

def process_messages(messages):
    """Handle incoming messages; this policy currently ignores them.

    Args:
        messages: incoming messages (unused).

    Returns:
        A new empty list on every call.
    """
    del messages  # intentionally unused
    return list()
